{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
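    "### 1. 实现基于用户的协同过滤（20分）\n",
    "\n",
    "基于 `triplet_dataset_sub.csv` 中的 (user, song, play_count) 记录实现基于用户的协同过滤，并用 Precision、Recall、Coverage 和 RMSE 评估推荐效果。"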
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 导入工具包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import datetime\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "#稀疏矩阵，存储打分表\n",
    "import scipy.io as sio\n",
    "import scipy.sparse as ss\n",
    "\n",
    "#数据到文件存储\n",
    "import pickle\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user</th>\n",
       "      <th>song</th>\n",
       "      <th>play_count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCKSGZ12A58A7CA4B</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCVTLJ12A6310F0FD</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SODLLYS12A8C13A96B</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOEGIYH12A6D4FC0E3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOFRQTD12A81C233C0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       user                song  play_count\n",
       "0  4e11f45d732f4861772b2906f81a7d384552ad12  SOCKSGZ12A58A7CA4B           1\n",
       "1  4e11f45d732f4861772b2906f81a7d384552ad12  SOCVTLJ12A6310F0FD           1\n",
       "2  4e11f45d732f4861772b2906f81a7d384552ad12  SODLLYS12A8C13A96B           3\n",
       "3  4e11f45d732f4861772b2906f81a7d384552ad12  SOEGIYH12A6D4FC0E3           1\n",
       "4  4e11f45d732f4861772b2906f81a7d384552ad12  SOFRQTD12A81C233C0           2"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Directory containing the raw data files\n",
    "dpath = './data/'\n",
    "\n",
    "# One row per (user, song, play_count) record\n",
    "triplet_csv = dpath + 'triplet_dataset_sub.csv'\n",
    "df_triplet = pd.read_csv(triplet_csv)\n",
    "df_triplet.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(37519, 3)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (rows, columns) of the raw triplet table\n",
    "df_triplet.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 计算打分值\n",
    "#### 打分值 = 该用户对该歌曲的播放次数 / 该用户的总播放次数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user</th>\n",
       "      <th>song</th>\n",
       "      <th>play_count</th>\n",
       "      <th>total_play_count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCKSGZ12A58A7CA4B</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCVTLJ12A6310F0FD</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SODLLYS12A8C13A96B</td>\n",
       "      <td>3</td>\n",
       "      <td>259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOEGIYH12A6D4FC0E3</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOFRQTD12A81C233C0</td>\n",
       "      <td>2</td>\n",
       "      <td>259</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       user                song  play_count  \\\n",
       "0  4e11f45d732f4861772b2906f81a7d384552ad12  SOCKSGZ12A58A7CA4B           1   \n",
       "1  4e11f45d732f4861772b2906f81a7d384552ad12  SOCVTLJ12A6310F0FD           1   \n",
       "2  4e11f45d732f4861772b2906f81a7d384552ad12  SODLLYS12A8C13A96B           3   \n",
       "3  4e11f45d732f4861772b2906f81a7d384552ad12  SOEGIYH12A6D4FC0E3           1   \n",
       "4  4e11f45d732f4861772b2906f81a7d384552ad12  SOFRQTD12A81C233C0           2   \n",
       "\n",
       "   total_play_count  \n",
       "0               259  \n",
       "1               259  \n",
       "2               259  \n",
       "3               259  \n",
       "4               259  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total play count of every user across all songs\n",
    "triplet_dataset_sub_song_sum_df = (\n",
    "    df_triplet[['user', 'play_count']]\n",
    "    .groupby('user')\n",
    "    .sum()\n",
    "    .reset_index()\n",
    "    .rename(columns={'play_count': 'total_play_count'})\n",
    ")\n",
    "\n",
    "# Attach each user's total to every (user, song) row; the join key is made explicit\n",
    "triplet_dataset_sub_song_merged = pd.merge(df_triplet, triplet_dataset_sub_song_sum_df, on='user')\n",
    "# Every user has a total, so the inner join must not drop or duplicate rows\n",
    "assert len(triplet_dataset_sub_song_merged) == len(df_triplet)\n",
    "triplet_dataset_sub_song_merged.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score of a (user, song) pair = fraction of the user's total plays spent on that song\n",
    "triplet_dataset_sub_song_merged['fractional_play_count'] = (\n",
    "    triplet_dataset_sub_song_merged['play_count']\n",
    "    .div(triplet_dataset_sub_song_merged['total_play_count'])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user</th>\n",
       "      <th>song</th>\n",
       "      <th>play_count</th>\n",
       "      <th>total_play_count</th>\n",
       "      <th>fractional_play_count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCKSGZ12A58A7CA4B</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "      <td>0.003861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCVTLJ12A6310F0FD</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "      <td>0.003861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SODLLYS12A8C13A96B</td>\n",
       "      <td>3</td>\n",
       "      <td>259</td>\n",
       "      <td>0.011583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOEGIYH12A6D4FC0E3</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "      <td>0.003861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOFRQTD12A81C233C0</td>\n",
       "      <td>2</td>\n",
       "      <td>259</td>\n",
       "      <td>0.007722</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       user                song  play_count  \\\n",
       "0  4e11f45d732f4861772b2906f81a7d384552ad12  SOCKSGZ12A58A7CA4B           1   \n",
       "1  4e11f45d732f4861772b2906f81a7d384552ad12  SOCVTLJ12A6310F0FD           1   \n",
       "2  4e11f45d732f4861772b2906f81a7d384552ad12  SODLLYS12A8C13A96B           3   \n",
       "3  4e11f45d732f4861772b2906f81a7d384552ad12  SOEGIYH12A6D4FC0E3           1   \n",
       "4  4e11f45d732f4861772b2906f81a7d384552ad12  SOFRQTD12A81C233C0           2   \n",
       "\n",
       "   total_play_count  fractional_play_count  \n",
       "0               259               0.003861  \n",
       "1               259               0.003861  \n",
       "2               259               0.011583  \n",
       "3               259               0.003861  \n",
       "4               259               0.007722  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview: raw counts, per-user totals and the derived score\n",
    "triplet_dataset_sub_song_merged.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 用 `train_test_split` 将 triplet_dataset_sub.csv 的打分数据划分为 80% 训练集和 20% 测试集（random_state=33，保证可复现）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 20% of the (user, song) records for evaluation; fixed seed for reproducibility\n",
    "X_train, X_test = train_test_split(\n",
    "    triplet_dataset_sub_song_merged,\n",
    "    test_size=0.2,\n",
    "    random_state=33,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 建立用户索引和歌曲索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct users and songs present in the training split\n",
    "unique_users = X_train['user'].unique()\n",
    "unique_items = X_train['song'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user</th>\n",
       "      <th>song</th>\n",
       "      <th>play_count</th>\n",
       "      <th>total_play_count</th>\n",
       "      <th>fractional_play_count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5218</th>\n",
       "      <td>23ec7765590efc05002ccf6551cbadaa953c2069</td>\n",
       "      <td>SOVCHUK12AB017F41F</td>\n",
       "      <td>33</td>\n",
       "      <td>113</td>\n",
       "      <td>0.292035</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19545</th>\n",
       "      <td>15e7692a0539b165bbe0dad78dea78e2ba5fb37f</td>\n",
       "      <td>SOTVFIU12AC46878B7</td>\n",
       "      <td>1</td>\n",
       "      <td>177</td>\n",
       "      <td>0.005650</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30995</th>\n",
       "      <td>e851e806941435c6d5748bd64c1c55e5c0551ab1</td>\n",
       "      <td>SOAUWYT12A81C206F1</td>\n",
       "      <td>1</td>\n",
       "      <td>143</td>\n",
       "      <td>0.006993</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27555</th>\n",
       "      <td>df29d6bfb3a8055621ebf87884405d475b325d91</td>\n",
       "      <td>SOVRMZU12AB017FE90</td>\n",
       "      <td>27</td>\n",
       "      <td>473</td>\n",
       "      <td>0.057082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30038</th>\n",
       "      <td>92d2e8ff105b7b7ddc163e740921cafdbcd815bf</td>\n",
       "      <td>SOSJRJP12A6D4F826F</td>\n",
       "      <td>1</td>\n",
       "      <td>113</td>\n",
       "      <td>0.008850</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           user                song  \\\n",
       "5218   23ec7765590efc05002ccf6551cbadaa953c2069  SOVCHUK12AB017F41F   \n",
       "19545  15e7692a0539b165bbe0dad78dea78e2ba5fb37f  SOTVFIU12AC46878B7   \n",
       "30995  e851e806941435c6d5748bd64c1c55e5c0551ab1  SOAUWYT12A81C206F1   \n",
       "27555  df29d6bfb3a8055621ebf87884405d475b325d91  SOVRMZU12AB017FE90   \n",
       "30038  92d2e8ff105b7b7ddc163e740921cafdbcd815bf  SOSJRJP12A6D4F826F   \n",
       "\n",
       "       play_count  total_play_count  fractional_play_count  \n",
       "5218           33               113               0.292035  \n",
       "19545           1               177               0.005650  \n",
       "30995           1               143               0.006993  \n",
       "27555          27               473               0.057082  \n",
       "30038           1               113               0.008850  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview of the training split (rows were shuffled by train_test_split)\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensions of the user-item score matrix\n",
    "n_users, n_items = len(unique_users), len(unique_items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map raw user / song IDs to contiguous row / column positions of the score matrix\n",
    "users_index = {u: j for j, u in enumerate(unique_users)}\n",
    "items_index = {i: j for j, i in enumerate(unique_items)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverted indexes:\n",
    "#   user_items[u] -> set of song indices user u has a training score for\n",
    "#   item_users[i] -> set of user indices that have a training score for song i\n",
    "user_items = defaultdict(set)\n",
    "item_users = defaultdict(set)\n",
    "\n",
    "# Sparse user-item score matrix R; DOK format supports cheap element-by-element filling\n",
    "user_item_scores = ss.dok_matrix((n_users, n_items))\n",
    "\n",
    "# Scan the training records once. Zipping the columns avoids building a row Series with .iloc per access.\n",
    "for user, song, score in zip(X_train['user'], X_train['song'], X_train['fractional_play_count']):\n",
    "    cur_user_index = users_index[user]\n",
    "    cur_item_index = items_index[song]\n",
    "\n",
    "    user_items[cur_user_index].add(cur_item_index)\n",
    "    item_users[cur_item_index].add(cur_user_index)\n",
    "\n",
    "    user_item_scores[cur_user_index, cur_item_index] = score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the finished DOK matrix to CSR, the compressed sparse format used from here on\n",
    "user_item_scores = user_item_scores.tocsr()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 实现基于用户的协同过滤； （20分） "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 计算每个用户的平均打分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "users_mu = np.zeros(n_users)\n",
    "for user_index in range(n_users):\n",
    "    rated = user_items[user_index]\n",
    "    # Mean score over the songs this user has in the training split\n",
    "    total = sum(user_item_scores[user_index, i] for i in rated)\n",
    "    users_mu[user_index] = total / len(rated)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
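    "#### 计算两个用户之间的相似度：用两个用户播放过的歌曲集合的交集大小除以并集大小（Jaccard 相似度）\n",
    "$$sim(u, v) = \\frac{|N(u) \\cap N(v)|}{|N(u) \\cup N(v)|}$$\n",
    "其中 $N(u)$ 表示用户 $u$ 在训练集中播放过的歌曲集合。"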
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def user_similarity(uid1, uid2):\n",
    "    \"\"\"Jaccard similarity of two users' training song sets.\n",
    "\n",
    "    Returns |A & B| / |A | B|, where A = user_items[uid1] and B = user_items[uid2]\n",
    "    (sets of song indices). Returns 0.0 when the users share no songs.\n",
    "    \"\"\"\n",
    "    items1 = user_items[uid1]\n",
    "    items2 = user_items[uid2]\n",
    "    in_items = len(items1 & items2)\n",
    "    if in_items == 0:  # no common songs (also covers empty sets, avoiding 0/0)\n",
    "        return 0.0\n",
    "    # Inclusion-exclusion: |A | B| = |A| + |B| - |A & B| (the overlap must not be counted twice)\n",
    "    union_items = len(items1) + len(items2) - in_items\n",
    "    return in_items / union_items"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 预计算好所有用户之间的相似性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ui=0 \n",
      "ui=100 \n",
      "ui=200 \n",
      "ui=300 \n",
      "ui=400 \n",
      "ui=500 \n",
      "ui=600 \n",
      "ui=700 \n"
     ]
    }
   ],
   "source": [
    "# Dense symmetric user-user similarity matrix with 1.0 on the diagonal.\n",
    "# A plain ndarray replaces np.matrix, which NumPy no longer recommends; scalar [i, j] indexing is unchanged.\n",
    "users_similarity_matrix = np.eye(n_users, dtype=float)\n",
    "\n",
    "for ui in range(n_users):\n",
    "    # progress indicator\n",
    "    if ui % 100 == 0:\n",
    "        print('ui=%d ' % ui)\n",
    "\n",
    "    # Similarity is symmetric: compute the upper triangle once and mirror it\n",
    "    for uj in range(ui + 1, n_users):\n",
    "        sim = user_similarity(ui, uj)\n",
    "        users_similarity_matrix[ui, uj] = sim\n",
    "        users_similarity_matrix[uj, ui] = sim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
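    "#### 预测用户对歌曲的打分：对相似用户打分相对其自身均值的偏差做相似度加权平均\n",
    "$$\\hat{r}_{ui} = \\bar{r}_u + \\frac{\\sum_{v \\in U(i)} sim(u, v)\\,(r_{vi} - \\bar{r}_v)}{\\sum_{v \\in U(i)} |sim(u, v)|}$$\n",
    "其中 $U(i)$ 为训练集中对歌曲 $i$ 有打分且与 $u$ 相似度非零的用户；若不存在这样的用户，则预测值取 $\\bar{r}_u$。"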
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def User_CF_pred(uid, iid):\n",
    "    \"\"\"Predict the score of user index `uid` for song index `iid` (user-based CF).\n",
    "\n",
    "    Mean-centred weighted average over the users with a training score for iid:\n",
    "        users_mu[uid] + sum_v sim(v, uid) * (r[v, iid] - users_mu[v]) / sum_v |sim(v, uid)|\n",
    "    Users with zero similarity are skipped; if none remain, users_mu[uid] is returned.\n",
    "    \"\"\"\n",
    "    weighted_dev = 0.0\n",
    "    weight_total = 0.0\n",
    "    for neighbour in item_users[iid]:\n",
    "        sim = users_similarity_matrix[neighbour, uid]\n",
    "        if sim == 0:\n",
    "            continue\n",
    "        # neighbour's deviation from their own mean score, weighted by similarity\n",
    "        weighted_dev += sim * (user_item_scores[neighbour, iid] - users_mu[neighbour])\n",
    "        weight_total += np.abs(sim)\n",
    "\n",
    "    # No similar user has scored this song: fall back to the user's own mean\n",
    "    if weight_total == 0:\n",
    "        return users_mu[uid]\n",
    "    return users_mu[uid] + weighted_dev / weight_total"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 对给定用户，推荐歌曲。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recommend(user):\n",
    "    \"\"\"Rank every song the user has no training score for by its predicted score.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    user : raw user ID; must be a key of users_index (KeyError otherwise).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    DataFrame with columns ['item_id', 'score'], sorted by score descending\n",
    "    (ties broken by larger item index first). NaN predictions are dropped.\n",
    "    \"\"\"\n",
    "    cur_user_id = users_index[user]\n",
    "\n",
    "    # songs this user already has a score for in the training split\n",
    "    cur_user_items = user_items[cur_user_id]\n",
    "\n",
    "    # inverse of items_index, built once per call instead of an O(n_items) list scan per song\n",
    "    index_to_item = {idx: item for item, idx in items_index.items()}\n",
    "\n",
    "    # (predicted score, item index) for every song the user has not scored\n",
    "    predictions = [\n",
    "        (User_CF_pred(cur_user_id, i), i)\n",
    "        for i in range(n_items)\n",
    "        if i not in cur_user_items\n",
    "    ]\n",
    "    # drop NaN before sorting: NaN makes tuple ordering inconsistent\n",
    "    predictions = [(score, i) for score, i in predictions if not np.isnan(score)]\n",
    "    predictions.sort(reverse=True)\n",
    "\n",
    "    # build the frame in one go rather than growing it row by row\n",
    "    records = [(index_to_item[i], score) for score, i in predictions]\n",
    "    return pd.DataFrame(records, columns=['item_id', 'score'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 用测试数据测试，并计算评价指标。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "52a6c7b6221f57c89dacbbd06854ca0dc415e9e6 is a new user.\n",
      "\n",
      "467e0e46181933c7e1a936e513ca55fbab4edaed is a new user.\n",
      "\n",
      "3ab78e39bddeaeb789edad041fff03050077417c is a new user.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Distinct users in the test split\n",
    "unique_users_test = X_test['user'].unique()\n",
    "\n",
    "# Number of songs recommended to each user\n",
    "n_rec_items = 10\n",
    "\n",
    "# Counters for precision and recall\n",
    "n_hits = 0\n",
    "n_total_rec_items = 0\n",
    "n_test_items = 0\n",
    "\n",
    "# Union of songs recommended to any user, used for coverage\n",
    "all_rec_items = set()\n",
    "\n",
    "# Residual sum of squares and number of scored test records, used for RMSE\n",
    "rss_test = 0.0\n",
    "n_rmse_records = 0\n",
    "\n",
    "for user in unique_users_test:\n",
    "    # Collaborative filtering cannot score users absent from the training split\n",
    "    if user not in users_index:\n",
    "        print(str(user) + ' is a new user.\\n')\n",
    "        continue\n",
    "\n",
    "    # This user's held-out records: the ground truth songs and scores\n",
    "    user_records_test = X_test[X_test.user == user]\n",
    "    test_songs = set(user_records_test['song'])\n",
    "\n",
    "    # Predicted scores for every song the user has no training score for, best first\n",
    "    rec_items = recommend(user)\n",
    "\n",
    "    # Top-N list; head() also copes with fewer than n_rec_items candidates\n",
    "    top_items = rec_items['item_id'].head(n_rec_items)\n",
    "    n_hits += sum(item in test_songs for item in top_items)\n",
    "    all_rec_items.update(top_items)\n",
    "    n_total_rec_items += len(top_items)\n",
    "\n",
    "    # Squared error for every test record whose song received a prediction\n",
    "    pred_by_item = dict(zip(rec_items['item_id'], rec_items['score']))\n",
    "    for item, score in zip(user_records_test['song'], user_records_test['fractional_play_count']):\n",
    "        if item not in pred_by_item:  # song never seen in training: CF cannot score it\n",
    "            print(str(item) + ' is a new item.\\n')\n",
    "            continue\n",
    "        rss_test += (pred_by_item[item] - score) ** 2\n",
    "        n_rmse_records += 1\n",
    "\n",
    "    # Number of ground-truth records, used for recall\n",
    "    n_test_items += user_records_test.shape[0]\n",
    "\n",
    "# Precision & Recall\n",
    "precision = n_hits / n_total_rec_items if n_total_rec_items else 0.0\n",
    "recall = n_hits / n_test_items if n_test_items else 0.0\n",
    "\n",
    "# Coverage: share of training songs recommended to at least one user\n",
    "coverage = len(all_rec_items) / n_items\n",
    "\n",
    "# RMSE over the records actually scored (new users / new songs were skipped above)\n",
    "rmse = np.sqrt(rss_test / n_rmse_records) if n_rmse_records else float('nan')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.012225274725274725"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Precision = hits / number of recommended songs\n",
    "precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.011865084655379284"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recall = hits / number of held-out test records of evaluated users\n",
    "recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2875"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Coverage = distinct recommended songs / songs in the training split\n",
    "coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.05073683662985348"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Root mean squared error of the predicted scores on the test split\n",
    "rmse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
